{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "65f3c082-63ef-4f91-bfa5-3d2bf2d84c62",
   "metadata": {},
   "source": [
    "# Understand Standard Modules in a Detector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71375de8-ba5c-49bb-bc9e-333eaa6553cc",
   "metadata": {},
   "source": [
    "A detector usually consists of:\n",
    "\n",
    "- a backbone, to generate multi-stage feature maps\n",
    "- a feature pyramid network, to fuse features of different stages\n",
    "- a dense prediction head, to predict objects at different locations\n",
    "- (optionally) a sparse head, for two-stage and multi-stage detectors\n",
    "\n",
    "The figure below is from YOLOv4:\n",
    "\n",
    "![](https://miro.medium.com/max/720/1*Z5GOPYFgh7_NTr7drt45mw.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b420800-f32c-4195-9a06-d3cc9d16f55f",
   "metadata": {},
   "source": [
    "In MMDetection, these components are defined under the `model` key of the config file.\n",
    "\n",
    "In this notebook, we investigate these modules interactively to understand the connection between Python modules and config files.\n",
    "\n",
    "We use RetinaNet and FCOS as examples.\n",
    "\n",
    "Architecture: \n",
    "\n",
    "<img src=\"https://production-media.paperswithcode.com/methods/Screen_Shot_2020-06-23_at_3.34.09_PM_SAg1OBo.png\" alt=\"Architecture\" style=\"width: 800px;\"/>"
   ]
  },
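  {
   "cell_type": "markdown",
   "id": "sketch-model-config-md",
   "metadata": {},
   "source": [
    "The link between config files and Python modules can be sketched with a plain dictionary. The cell below is a minimal, partial sketch rather than a complete MMDetection config: each `type` names a class, and the remaining keys mirror the constructor arguments used later in this notebook. The variable name `model_cfg` is hypothetical, and a real detector config needs further fields that are omitted here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-model-config-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partial sketch of a `model` config (assumption: not a complete, trainable config).\n",
    "model_cfg = dict(\n",
    "    type='FCOS',\n",
    "    backbone=dict(type='ResNet', depth=18, out_indices=(0, 1, 2, 3)),\n",
    "    neck=dict(\n",
    "        type='FPN',\n",
    "        in_channels=[64, 128, 256, 512],\n",
    "        start_level=1,\n",
    "        out_channels=256,\n",
    "        num_outs=5),\n",
    "    bbox_head=dict(type='FCOSHead', num_classes=20, in_channels=256))\n",
    "\n",
    "# Each component config is a class name plus that class's keyword arguments.\n",
    "for name in ('backbone', 'neck', 'bbox_head'):\n",
    "    kwargs = {k: v for k, v in model_cfg[name].items() if k != 'type'}\n",
    "    print(name, '->', model_cfg[name]['type'], kwargs)"
   ]
  },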
  {
   "cell_type": "markdown",
   "id": "bae3fd61-8371-4283-af8c-61644047ef4b",
   "metadata": {},
   "source": [
    "## Backbone\n",
    "\n",
    "We define a ResNet backbone that outputs the feature maps C2 to C5.\n",
    "\n",
    "In `out_indices`, index 0 corresponds to C2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8fe59c3d-b8ea-4f54-a307-330c99d3b9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmdet.models.backbones import ResNet\n",
    "\n",
    "backbone = ResNet(depth=18, out_indices=(0,1,2,3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ab16a77-b8d8-459c-9240-119b565310a8",
   "metadata": {},
   "source": [
    "We forward a random tensor through the backbone and check the output shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0b2c49cf-7925-478c-8207-e464f4fbe193",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "img = torch.rand(1,3,1000,600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eda9ab5d-204a-44bc-9431-b7dd21634f37",
   "metadata": {},
   "outputs": [],
   "source": [
    "bbout = backbone(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4c112978-2159-41e0-b69f-bd45035484da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 64, 250, 150])\n",
      "torch.Size([1, 128, 125, 75])\n",
      "torch.Size([1, 256, 63, 38])\n",
      "torch.Size([1, 512, 32, 19])\n"
     ]
    }
   ],
   "source": [
    "for o in bbout:\n",
    "    print(o.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "952b11b3-49b5-4129-986c-a8f6f915cb38",
   "metadata": {},
   "source": [
    "These feature maps are downsampled by 4x, 8x, 16x, and 32x relative to the input, respectively."
   ]
  },
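  {
   "cell_type": "markdown",
   "id": "sketch-backbone-strides-md",
   "metadata": {},
   "source": [
    "The printed shapes can be checked by hand. The cell below is a small sketch that assumes every stride-2 layer maps a side of length `n` to `ceil(n / 2)`, which agrees with the shapes printed above; the helper name `halve` is hypothetical and not part of MMDetection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-backbone-strides-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def halve(n, times):\n",
    "    # Assumption: each stride-2 layer maps a side of length n to ceil(n / 2).\n",
    "    for _ in range(times):\n",
    "        n = math.ceil(n / 2)\n",
    "    return n\n",
    "\n",
    "height, width = 1000, 600  # spatial size of the random input above\n",
    "for name, num_halvings in [('C2', 2), ('C3', 3), ('C4', 4), ('C5', 5)]:\n",
    "    size = (halve(height, num_halvings), halve(width, num_halvings))\n",
    "    print(name, f'{2 ** num_halvings}x', size)"
   ]
  },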
  {
   "cell_type": "markdown",
   "id": "55338895-cf92-4454-bb3a-7ffc57bff36d",
   "metadata": {},
   "source": [
    "## Neck\n",
    "\n",
    "We define a plain feature pyramid network (FPN) to fuse the feature maps from different backbone stages.\n",
    "\n",
    "`start_level=1` means that the FPN starts from index 1 of the backbone outputs, which is C3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "48abfc63-9d84-4ca6-8d8f-668962ae9763",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmdet.models.necks import FPN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8e2c5b18-c558-481f-b147-fbf88c4cfc61",
   "metadata": {},
   "outputs": [],
   "source": [
    "neck = FPN(in_channels=[64, 128, 256, 512], start_level=1, out_channels=256, num_outs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f0c96a1c-c9c1-4a1a-a5b4-ed38e546e720",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 256, 125, 75])\n",
      "torch.Size([1, 256, 63, 38])\n",
      "torch.Size([1, 256, 32, 19])\n",
      "torch.Size([1, 256, 16, 10])\n",
      "torch.Size([1, 256, 8, 5])\n"
     ]
    }
   ],
   "source": [
    "ncout = neck(bbout)\n",
    "\n",
    "for o in ncout:\n",
    "    print(o.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cda2a5b-072b-4d25-9212-42e8bce52697",
   "metadata": {},
   "source": [
    "The neck outputs fused features at 5 levels, all with the same number of channels (256).\n",
    "The downsampling rates range from 8x to 128x."
   ]
  },
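  {
   "cell_type": "markdown",
   "id": "sketch-fpn-levels-md",
   "metadata": {},
   "source": [
    "How five outputs arise from three backbone inputs can be sketched with simple arithmetic. The cell below assumes that the extra levels are obtained by repeated stride-2 downsampling of the coarsest output and that each halving maps a side of length `n` to `ceil(n / 2)`; both assumptions are consistent with the shapes printed above. The names `fpn_strides`, `fpn_sizes`, and `size_at` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-fpn-levels-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "in_channels = [64, 128, 256, 512]  # C2..C5, as passed to FPN above\n",
    "start_level, num_outs = 1, 5\n",
    "\n",
    "backbone_strides = [4 * 2 ** i for i in range(len(in_channels))]  # [4, 8, 16, 32]\n",
    "used_strides = backbone_strides[start_level:]  # C3, C4, C5\n",
    "num_extra = num_outs - len(used_strides)  # extra coarse levels added on top\n",
    "fpn_strides = used_strides + [used_strides[-1] * 2 ** (i + 1) for i in range(num_extra)]\n",
    "\n",
    "def size_at(stride, side):\n",
    "    # Assumption: each halving maps a side of length n to ceil(n / 2).\n",
    "    for _ in range(int(math.log2(stride))):\n",
    "        side = math.ceil(side / 2)\n",
    "    return side\n",
    "\n",
    "fpn_sizes = [(size_at(s, 1000), size_at(s, 600)) for s in fpn_strides]\n",
    "print(fpn_strides)\n",
    "print(fpn_sizes)"
   ]
  },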
  {
   "cell_type": "markdown",
   "id": "ebab7af5-40ef-4371-8480-0582aa674878",
   "metadata": {},
   "source": [
    "## Anchor-free Heads\n",
    "\n",
    "An anchor-free head slides over the feature map and computes predictions directly at each location (in contrast to anchor-based heads, which compute predictions relative to anchors)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9649790c-5227-4225-bf0c-15ba85895da4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmdet.models.dense_heads import FCOSHead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e7c40443-5db1-430f-b7eb-0845ed727118",
   "metadata": {},
   "outputs": [],
   "source": [
    "head = FCOSHead(num_classes=20, in_channels=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "827cdcab-f690-440a-bb75-4c29824ee9d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "type: <class 'tuple'>\n",
      "length: 3\n"
     ]
    }
   ],
   "source": [
    "hout = head(ncout)\n",
    "\n",
    "print(F\"type: {type(hout)}\")\n",
    "print(F\"length: {len(hout)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0bc6a519-4124-4601-9022-dce1e90e50f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 20, 125, 75])\n",
      "torch.Size([1, 20, 63, 38])\n",
      "torch.Size([1, 20, 32, 19])\n",
      "torch.Size([1, 20, 16, 10])\n",
      "torch.Size([1, 20, 8, 5])\n",
      "torch.Size([1, 4, 125, 75])\n",
      "torch.Size([1, 4, 63, 38])\n",
      "torch.Size([1, 4, 32, 19])\n",
      "torch.Size([1, 4, 16, 10])\n",
      "torch.Size([1, 4, 8, 5])\n",
      "torch.Size([1, 1, 125, 75])\n",
      "torch.Size([1, 1, 63, 38])\n",
      "torch.Size([1, 1, 32, 19])\n",
      "torch.Size([1, 1, 16, 10])\n",
      "torch.Size([1, 1, 8, 5])\n"
     ]
    }
   ],
   "source": [
    "for res in hout:\n",
    "    for o in res:\n",
    "        print(o.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc9f367-8836-405c-94a5-b5f0d8aca11d",
   "metadata": {},
   "source": [
    "The FCOS head predicts:\n",
    "\n",
    "- class scores (20 channels, one per class)\n",
    "- bounding box coordinates (4 channels)\n",
    "- centerness (1 channel)\n",
    "\n",
    "at all positions on all levels of feature maps.\n",
    "\n",
    "During inference, post-processing such as score thresholding and NMS produces the final detection boxes.\n",
    "\n",
    "During training, the predictions are compared with the ground truth to compute the loss."
   ]
  },
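  {
   "cell_type": "markdown",
   "id": "sketch-fcos-count-md",
   "metadata": {},
   "source": [
    "To get a feel for how dense these predictions are, the cell below counts the prediction locations over all five levels, using the feature-map sizes printed above. It is plain arithmetic, and the variable names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-fcos-count-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 20\n",
    "level_sizes = [(125, 75), (63, 38), (32, 19), (16, 10), (8, 5)]  # (H, W) per level\n",
    "\n",
    "# Every location carries class scores, 4 box values, and 1 centerness value.\n",
    "values_per_location = num_classes + 4 + 1\n",
    "num_locations = sum(h * w for h, w in level_sizes)\n",
    "\n",
    "print('locations:', num_locations)\n",
    "print('predicted values:', num_locations * values_per_location)"
   ]
  },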
  {
   "cell_type": "markdown",
   "id": "d83b2855-b4a4-47a1-9e82-4db27d29c139",
   "metadata": {},
   "source": [
    "## Anchor-based Heads\n",
    "\n",
    "An anchor-based head is more complex to implement because it contains an anchor generation module. Let's investigate this module first."
   ]
  },
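  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "# AnchorGenerator moved between major versions, so try the 3.x location first.\n",
    "try:\n",
    "    from mmdet.models.task_modules import AnchorGenerator  # MMDetection 3.x\n",
    "except ImportError:\n",
    "    from mmdet.core.anchor import AnchorGenerator  # MMDetection 2.x\n",
    "\n",
    "# One scale and three aspect ratios per location, on five levels with the given strides.\n",
    "ag = AnchorGenerator(\n",
    "    scales=[8],\n",
    "    ratios=[0.5, 1.0, 2.0],\n",
    "    strides=[4, 8, 16, 32, 64])"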
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[-22.6274, -11.3137,  22.6274,  11.3137],\n",
      "        [-16.0000, -16.0000,  16.0000,  16.0000],\n",
      "        [-11.3137, -22.6274,  11.3137,  22.6274]]), tensor([[-45.2548, -22.6274,  45.2548,  22.6274],\n",
      "        [-32.0000, -32.0000,  32.0000,  32.0000],\n",
      "        [-22.6274, -45.2548,  22.6274,  45.2548]]), tensor([[-90.5097, -45.2548,  90.5097,  45.2548],\n",
      "        [-64.0000, -64.0000,  64.0000,  64.0000],\n",
      "        [-45.2548, -90.5097,  45.2548,  90.5097]]), tensor([[-181.0193,  -90.5097,  181.0193,   90.5097],\n",
      "        [-128.0000, -128.0000,  128.0000,  128.0000],\n",
      "        [ -90.5097, -181.0193,   90.5097,  181.0193]]), tensor([[-362.0387, -181.0193,  362.0387,  181.0193],\n",
      "        [-256.0000, -256.0000,  256.0000,  256.0000],\n",
      "        [-181.0193, -362.0387,  181.0193,  362.0387]])]\n"
     ]
    }
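   ],
   "source": [
    "# Anchors tiled over one feature map (level 1, stride 8); kept only for illustration.\n",
    "priors = ag.single_level_grid_priors(featmap_size=(125,75), level_idx=1)\n",
    "# Base anchors of every level, centered at the origin; these are printed and drawn below.\n",
    "bbx = ag.gen_base_anchors()\n",
    "print(bbx)"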
   ],
   "metadata": {
    "collapsed": false
   }
  },
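  {
   "cell_type": "markdown",
   "id": "sketch-base-anchors-md",
   "metadata": {},
   "source": [
    "The base anchors printed above can be reproduced by hand. The cell below is a minimal sketch that assumes each ratio is height divided by width and that every anchor is centered at the origin with area `(stride * scale) ** 2`; under these assumptions it matches the printed values. The function `base_anchors` is hypothetical and is not the library implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-base-anchors-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def base_anchors(stride, scales, ratios):\n",
    "    # Returns (x1, y1, x2, y2) boxes centered at the origin, one per (ratio, scale) pair.\n",
    "    anchors = []\n",
    "    for ratio in ratios:\n",
    "        h_ratio = math.sqrt(ratio)  # assumption: ratio = height / width\n",
    "        w_ratio = 1 / h_ratio\n",
    "        for scale in scales:\n",
    "            w = stride * scale * w_ratio\n",
    "            h = stride * scale * h_ratio\n",
    "            anchors.append((-w / 2, -h / 2, w / 2, h / 2))\n",
    "    return anchors\n",
    "\n",
    "for stride in [4, 8, 16, 32, 64]:\n",
    "    for box in base_anchors(stride, scales=[8], ratios=[0.5, 1.0, 2.0]):\n",
    "        print(stride, [round(v, 4) for v in box])"
   ]
  },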
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAEuCAYAAAAwQP9DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUaklEQVR4nO3dXWxb9f3H8c+xY8d5cEnKaBv4FwpltHR0lDIxqCj9s63VhMSDKExDXCBxsaveINgNIBAMkCak8XCBxB0qQ5o0kKCqQFq7/1ihLZSWNUnJaELcNVkS58FxHDt2/Hj+F0DAJMVxsZ18k/cL9YITn9/5OQlv7NPj33Fc1xUAWOJZ6AkAQLkIFwBzCBcAcwgXAHMIFwBzCBcAc+pKfJ1rJQAsFOdcX+AVFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAc+oWegL40pnoGR0IHZDrugs9FczBlatNF23S9ku3y3GchZ7Oske4FomO4Q5FkhHdfOnNCz0VzCGcCOtA7wFtv3T7Qk8FIlyLyvqV67X9Mv7DWIx6x3t1auTUQk8DX+EcFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHPqFnoC1riuq3gmrnwhX9FxE5mEMvmMoqloRcdFZcTSMaVyKUWno3LkVGxcx3EU9Afl9XgrNuZyQLjKFM/E9fu//V6NvkY5TuV+gXvHe5V38+oc6azYmKiciekJdY12KZvPVvTnHklG9NBND2nLmi0VG3M5IFxlyhfyavQ16plfPCOf11excfd371c6l9buTbsrNiYqJxQN6c8df9YTO56o6LivHn9VyWyyomMuB4TrPDiOI5/XJ7/XX7ExvY5XXo+3omOicnwen7yOVz6Pr6KvuHiLeH44OQ/AHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcs0s3n4meUcdwR82Pm8gk1Dveq/3d++V1Krfs7rGBY8oWsqr31ldszLmk82nF0/GqHqPW6uvqFfQHq3qMcCKs05HT2nd6X0XH7RjuUDwd1+jUaEXHnY+2YJt+dvHP5HHsvX4xG64DoQOKJCNav3J9TY+byWeUd/NK59IVXS88W8gqV8gpnU9XbMy5HDp7SB7How0XbqjqcWplMj2po/89qvs331/RteC/K5PPVOXnkyvklC1kq/5z/65kNqmDoYO6bs118ngJV824rqubL71Z2y/bXtPjRlNRdY50avem3RW9sUW9t17pfFq/+clvKjbmXFLZlNqCbdq1fldVj1Mr4URY0emo7tl0T1VvPNE73qux5Jju3XRvRQM5lhzTljVbtG3ttoqNOR+RZESdw3ZvhWcvtQCWPcIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcsx+yxjdyhZwG44PKFXIlHzuWHJPjOApFQzWYWfWNJccUm47pzMSZksuzOHK0unm1Gn2NNZodqoVwLQGD8UE9+vdHdXnL5SVXLmgfblfQH9QX41/UaHbVlcgk1DnSqddOvlYyXEPxId2x4Q7dvuH2Gs0O1UK4loBcIafLWy7XU7c+JUffH6697XvVFmzTzit21mh21RVOhPXCRy/oqf99qmS43v787Zqve4XqIFxLhOM4cuTMe62oai66V0vffh6l
npPjOJJb7RmhFjg5D8AcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcruNaJNL5tA6dPaRUNlX2vmPJMbUPt2tv+96Sjz3cf1gtgRYNxYfmPX6ukNN4alxuDS6CCvqDZX0kJ5aO6bORz/R6x+slL779dOhTFdyCpjJTZc9rZGpEE9MTZe+H6iBci0Q8HZfH8agt2Fb2vo7jKOgPzmvflkCLVjasLOs4oWhInSOduu3Ht5U9t3KcHjutwfhgWccJ1AXU5G/SmuY1Ja+cb51oVb6QP6/vcSafUftwe9n7oToI1yKy4cIN53WH6VA0pC/Gv9DOK3aWvHp8KD5U9p2sT4ZPKpFJ6LfX/LbsuZXj0NlD+nzs87LmFk6EdWLohHZesbPknayT2aRyhdx5fY87hzv1Qd8HZe+H6uAcFwBzCBcAcwgXAHMIFwBzODmPGa479+UOruvO/KnF8c91nKWyFA9+OMKFIn/t+qs6hjuKrokanhpWKBrSZHqyqsfui/VpLDWm/lh/0fYV9Su054Y9avA1VPX4sINwoUjHcId+feWvdUnwkpltXaNdOtx/WA9e92BVj31s4Jh6o72675r7ZrZlC1m99NFLms5NEy7MIFwo4sjRJcFLdHnr5TPbYumYVjasLNpWDf2T/YqlY0XHyeQz8nv9VT0u7OHkPABzCBcAc3iruMzkCjmFoiGdDJ+c9TXXdTU8Nayu0S7F0rGZ7V9/hvC7+7iuq+nc9Lzu5/htjuOo0dc467OFPZEe9cf6i46TzWc1mhzVqZFTCtYHZ401lhw7rw+mwzbCtcyMp8bVOdKpRCYx62uu6yoUDelw/2GtbFg5s30wPqjuSLcOhg4WPT6bz2p/z35du/pa1Xnm/6vUPtyuGy+5URc1XVS0PRQNaWRqpOg4+UJefbE+vf+f9+c8OT8xPaG+WF9NVq7A4kG4lhlXrm778W1zfmDadV1Npif14HUPFp0gPxk+qYOhg3pk2yNFj09lUxqeGtYzv3hGzf7m+R3fdfXsB89q99W7dfVFVxd97esPWf/u+t/NbMvkMwonwtpzwx61NrTOGu/sxFk9/LeH53VsLB2c4wJgDq+4ULbuSLeO9B9RNp9V12iX3uh4Q/V19fPe/19D/5LX8erjgY+1rmWddly2g6viURbChbId6T+iaCqqqy68Ss3+Zl0cvLisi0NX1K/Q6ubVavY3a9/pfdp+6XZ5ne9fSwv4NsKFklzXVTafnfnbu2w+q6suvErb1m7Te1+8p5vW3qQmX5NcuSq4hTnH8Mgz86rqw74PtbVtq1oDrTrSf0SpXEpex6t0Ll31z0NiaSBcKGk6N639Pfs1PDUs6cuPADX7m/XeF+/pcN9hPfmPJ+Xz+hROhDU6NaqALyBHzszf9OULeUlfrvDqOI6ODRxTKBpSoC6gE0Mn9OjfH5XH8ag/1q/rL75+wZ4n7CBcKClXyOna1dfqmV88I0l6o+MNXRy8WDetvUlP/uNJPbHjCfm9fr3177dUX1evXVcUL408GB/U3va9euyWxxSoC+j5w8/rzo13qiXQopc/fllP3/q0PI5H7/a8q67RroV4ijCGcGFe6jx1M5c81NfVq8HXoCZfk3xen/xev1746AUdGzgmv9evE4MnivaNp+PqGOnQ4//3uB6/5XH5vD411H25v9/rV5OvSV6Pt6wT/FjeCBfOmytX4URYb/37LR0bOKY1zWsUrA9q86rNRY8bT40rnAhrODGsN7veVNdol+7aeNfCTBpLAuHCeSu4BY1Ojaq+rl5+r1/B+qBa6lu0qmlV0eM8jkeNvkb5vD4F6gKKTcc0nZteoFljKSBc+EECvoB2XbFLJwZPaPOqzVrVtEq3b7i96DH9sX71jPeo3luvW9fdqhNDJ84xGjA/XDmPH6TU3aOBaiBc+EH4cDMWAuECYA7nuPCD5At5DcYHFU/HNZ4an7mQ9NsG44OaTE/K7/UrnAgrmU0u0GyxVBAunDfPVy/Y97bvVcdIh8KJ
sBp9jeoZ7yl63GR6Up8Ofao6T52S2aQG44PyeXwLMWUsEbxVxHlzHEcbLtygx255TNe3XS+/1y+f16d6b33RH7/XrzpPndauWKtHtj2iX17+S26AgR+EV1yYt7k+AO04jgJ1AT1+y+N6s+tNBeoCunXdrUWP+frt4SPbHtFlF1w277GBcyFcKMlxHLUPt+vZD56V9OV6WivqV+jDvg91bOCYnj/8vHxe35dr1U/HZl2n9fXbw7+c+osk6Z//+adGp0bV6GvU0f6jeu6D5+Q4jroj3bpy5ZU1f36wh3ChpEZfo2685Ebtvnq3JMnreLW6ebW2tm1VKBrSnRvvVENdg+7aeNc5r4j3eXwzbw9Hp0a1a/0uXRC4QCNTI7r76rvlcTw6Pnhc8Uy8Zs8LdhEulORxPLqo6aKZNeI/HvhYzf5mtQZaFagLqCXQoiZf07zHa/Q16oLABWoJtKi1oVUbf7RRXo9Xo8lRfT72ebWeBpYQwoWyrWtZp32n9+lI/xGdGDqhlz9+uayT7Uf7j2pkakStDa1a37qeZZtRNsKFsu24bIe2X7pdqVxKj/79UT1969NlveJ67oPndPfVd2vjjzbKcRw+NoSyES6UzXEceR2vvI5XHscjj+OR1zO/NeNd15XjOGXtA3wX4VqGTo+d1qGzh2Ztd11XfbE+HRs4pv7Jb65+74n0KBQNzdonnUurP9avd3veLWsRwO5It44PHtdocrRoe3u4XX2xvqLjZPNZDcQHdKT/yJx3sh5ODJd9J23YR7iWmaA/qMH44DlPgo+lxtQb7VUsHZvZ1h/r18jUyKx9XNfV9RdfX/Zyy1euvFLxTHzWeH2xPoUT4aLt+UJe46lx9Yz3qNHXOGusyfSkVjas5O3mMkO4lplGX6Nu+/Ft2rV+16yvua6r/li/7rvmvjnvZP3tO0xXw7nuZN0d6dYD1z4w552sw4mw/nT0T1WdFxYfPvIDwBzCBcAc3ipilmwhq0w+882/57PKF/JF26py3DmOk81nWawQsxAuFFlRv0IvffRS0QWlo8nRmRPn1TQQH9B4alzdke6Zba5cTeemVefhVxXf4LcBRfbcsGfW5w1PjZzS+/95X3tu2FPVYx/pP6Ke8R49cO0DRdu/fU9HQCJc+BbHcdTga1CDr6Foe7A+qAZfw5x/q1dJwfqgGn2NVT8O7OPkPABzCBcAcwgXAHMIFwBzODm/iEymJ8/rkoOx5JgSmYTCiXDJta1i6ZgCdYGyjjOWHNPE9ITOTpwte27lGE4Ml/09GJka0VRmSuFEuORqE9FUVHk3f17f40gqwvVkiwjhWiTq6+p19L9HFZ2Olr1vbDqmzpFOvfDRCyUf+9nIZ2ryN81aF/77pLIp9cX69PDfHi57buXIFXJa2bCyrM8eTmWmdHzwuF786MWS0e6J9MiVW3Sd2HxNTE8o4A2UvR+qg3AtEkF/UPdvvl/3bLqn7H3PTJzRaydf01P/+1TJx77e8brWNK/Rzit2lnWMWr3aKHeVh3AirBc/elHP/fI5eZzvP/Pxzul3lCvkZtbOL0coGtLrHa+XvR+qg3AtIo7jnNfiel8v5udxPCVfdThaWov4eT3eeS9MWO6ih9/dl6VzFg9OzgMwh3ABMIdwATCHcAEwh3ABMIdwATCHcAEwh+u4lgBHjobiQ3r787dLXsf16dCnap1oVTKbrNHsqiuaiqon0qN3Tr9T8gLUk+GT2rxqc41mhmoiXEvA6ubVumPDHUrn0yp1gXvBLShfyC+Zm6jm3bxcucoVciXDtXnVZv38f35eo5mhmgjXEtDoa9TtG26f12OnMlNqC7bNeV9Fi8KJsLoj3dp99e4l82kAlMY5LgDmEC4A5hAuAOYQLgDmEC4A5hAuAOYQLgDmEC4A5hAuAOYQLgDmEC4A5hAuAOaY/ZC1K1fhRFi94701PW4sHdPE9IRC0ZB8Hl/Fxg0nwsrkM1V/PiNTI8rkM+oc7qzqcWolkorM/DxKrQ7xQ/TF
+jSeGldvtLeitykbnRrVwORAzX+PJ6YnlHfzNT1mJTmu+73roCzae44fOntIB3oPVPWXdS6pXEqHzh7Sr674lbxO5VYjOB05rVwhp59c9JOKjTmXiekJRVKRmn/fqsWVq4A3oLZgW1XvezieGlfHcId2rNtR0eN8MviJ1jSv0doVays25ny4crWuZZ0euPaBxbyqxjm/0WbDVWLeVROdjuoP//yD/rjzjxV9xbXv9D6l82ndu+neio2JyumN9mpv+9553S28HK988oq2rNmibWu3VXTccpRafHIBnXNiZt8qLtQ325Ezc+xqzGER/xIta85X/0iV/Rk5jjPzB/O3NN4vAFhWCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHPMLt28UBzHUSQZ0avHX63oTQY6hjuUK+Q0lhyr2JionNGpUX0y+Ile+eSVii6zfDJ8UlvbtlZsvOWCcJUp6A/qoZseUjKbrOi48XRc2UJWW9Zsqei4qIyByQGdjZ3VljVbKhqurW1bdc2qayo23nJBuMrk9XirEpfRqVGl8+kFvdsLzq13vFenRk5p29pt3NhiEeAcFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHPqFnoC+EYym1QkGVnoaWAOE9MTcuUu9DTwFcK1SLQF23QwdFCdw50LPRXMIe/m9dPVP13oaeArjut+7/9F+F9MjRTcgvKF/EJPA9/D43jkcTxyHGehp7JcnPMbTbgALFbnDBcn5wGYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYQ7gAmEO4AJhDuACYU1fi605NZgEAZeAVFwBzCBcAcwgXAHMIFwBzCBcAcwgXAHP+H5rRgrUXkbTEAAAAAElFTkSuQmCC\n"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mmengine.visualization import Visualizer\n",
    "import numpy as np\n",
    "\n",
    "vis = Visualizer(image = np.ones((1000,1000,3))*255)\n",
    "for b in bbx:\n",
    "    vis.draw_bboxes(b+500)\n",
    "vis.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Anchors of different sizes are assigned to feature maps at different levels, with larger anchors on coarser (higher-stride) levels."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We then construct an anchor-based head based on this anchor generator."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [],
   "source": [
    "from mmdet.models.dense_heads import RetinaHead\n",
    "abhead = RetinaHead(in_channels=256,\n",
    "                    num_classes=20,\n",
    "                    anchor_generator=dict(\n",
    "                        type='AnchorGenerator',\n",
    "                        octave_base_scale=4,\n",
    "                        scales_per_octave=3,\n",
    "                        ratios=[0.5, 1.0, 2.0],\n",
    "                        strides=[8, 16, 32, 64, 128]))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "from mmdet.models.dense_heads import RPNHead\n",
    "# This rebinds `abhead`, so the outputs below come from RPNHead, not from the RetinaHead above.\n",
    "abhead = RPNHead(in_channels=256,\n",
    "                 anchor_generator=dict(\n",
    "                    type='AnchorGenerator',\n",
    "                    scales=[8, 16, 32],\n",
    "                    ratios=[0.5, 1.0, 2.0],\n",
    "                    strides=[8, 16, 32, 64, 128]))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [
    "# The head returns per-level classification scores and per-level box regression outputs.\n",
    "out = abhead(ncout)\n",
    "cls, bbx = out"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 9, 125, 75])\n",
      "torch.Size([1, 9, 63, 38])\n",
      "torch.Size([1, 9, 32, 19])\n",
      "torch.Size([1, 9, 16, 10])\n",
      "torch.Size([1, 9, 8, 5])\n"
     ]
    }
   ],
   "source": [
    "for o in cls:\n",
    "    print(o.shape)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 36, 125, 75])\n",
      "torch.Size([1, 36, 63, 38])\n",
      "torch.Size([1, 36, 32, 19])\n",
      "torch.Size([1, 36, 16, 10])\n",
      "torch.Size([1, 36, 8, 5])\n"
     ]
    }
   ],
   "source": [
    "for o in bbx:\n",
    "    print(o.shape)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
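  {
   "cell_type": "markdown",
   "id": "sketch-head-channels-md",
   "metadata": {},
   "source": [
    "The channel counts printed above follow from the number of anchors per location. The cell below is a small sketch that assumes sigmoid-style classification, with one score per anchor and class and no separate background channel; the helper `head_channels` is hypothetical. Under this assumption the RPN head yields 9 and 36 channels, matching the output, while a RetinaNet head with 20 classes would yield 180 classification channels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-head-channels-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def head_channels(num_scales, num_ratios, num_classes):\n",
    "    # One anchor per (scale, ratio) pair at every location.\n",
    "    num_anchors = num_scales * num_ratios\n",
    "    cls_channels = num_anchors * num_classes  # one score per anchor and class\n",
    "    reg_channels = num_anchors * 4  # four box values per anchor\n",
    "    return cls_channels, reg_channels\n",
    "\n",
    "# RPNHead above: scales=[8, 16, 32], three ratios, a single objectness class.\n",
    "print(head_channels(num_scales=3, num_ratios=3, num_classes=1))\n",
    "# RetinaHead above: scales_per_octave=3, three ratios, 20 classes.\n",
    "print(head_channels(num_scales=3, num_ratios=3, num_classes=20))"
   ]
  },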
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "837d9a27-6b68-4f03-afe5-902710d4fc8d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "e678ac14dff038bc46ab45f1b54ebd32d0ce2b34611f4e1e107e2941d3b654ca"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
